{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Relay 函数级 Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "\n",
    "@relay.transform.function_pass(opt_level=1)\n",
    "class TestReplaceFunc:\n",
    "    \"\"\"简单的测试函数，将一个参数替换为另一个参数。\"\"\"\n",
    "\n",
    "    def __init__(self, new_func):\n",
    "        self.new_func = new_func\n",
    "\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        innerstr1 = \"=\"*40\n",
    "        innerstr2 = \"*\"*40\n",
    "        des = f\"func:\\n{innerstr1}\\n{func}\\n{innerstr2}\\n\"\n",
    "        des += f\"mod\\n{innerstr1}:\\n{mod}\\n{innerstr2}\\n\"\n",
    "        des += f\"ctx:\\n{innerstr1}\\n{ctx}\\n\"\n",
    "        print(des)\n",
    "        return self.new_func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "func:\n",
      "========================================\n",
      "fn (%x: Tensor[(10, 20), float32]) {\n",
      "  log(%x)\n",
      "}\n",
      "****************************************\n",
      "mod\n",
      "========================================:\n",
      "def @main(%x: Tensor[(10, 20), float32]) {\n",
      "  log(%x)\n",
      "}\n",
      "\n",
      "****************************************\n",
      "ctx:\n",
      "========================================\n",
      "Pass context information: \n",
      "\topt_level: 2\n",
      "\trequired passes: []\n",
      "\tdisabled passes: []\n",
      "\tconfig: {}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "x = relay.var(\"x\", shape=(10, 20))\n",
    "f1 = relay.Function([x], x)\n",
    "f2 = relay.Function([x], relay.log(x))\n",
    "fpass = TestReplaceFunc(f1)\n",
    "assert fpass.info.opt_level == 1\n",
    "assert fpass.info.name == \"TestReplaceFunc\"\n",
    "mod = tvm.IRModule.from_expr(f2)\n",
    "mod = fpass(mod)\n",
    "# wrap in expr\n",
    "mod2 = tvm.IRModule.from_expr(f1)\n",
    "mod2 = tvm.relay.transform.InferType()(mod2)\n",
    "assert tvm.ir.structural_equal(mod[\"main\"], mod2[\"main\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可以直接装饰函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "@relay.transform.function_pass(opt_level=1)\n",
    "def transform(expr, mod, ctx):\n",
    "    ..."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('tvmx': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "e579259ee6098e2b9319de590d145b4b096774fe457bdf04260e3ba5c171e887"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
